{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a sample implementation of Tree SHAP written in Python for easy reading."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sklearn.ensemble\n",
    "import shap\n",
    "import numpy as np\n",
    "import numba\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load boston dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(506, 13)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X,y = shap.datasets.boston()\n",
    "X.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train sklearn random forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=4,\n",
       "           max_features='auto', max_leaf_nodes=None,\n",
       "           min_impurity_decrease=0.0, min_impurity_split=None,\n",
       "           min_samples_leaf=1, min_samples_split=2,\n",
       "           min_weight_fraction_leaf=0.0, n_estimators=1000, n_jobs=1,\n",
       "           oob_score=False, random_state=None, verbose=0, warm_start=False)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = sklearn.ensemble.RandomForestRegressor(n_estimators=1000, max_depth=4)\n",
    "model.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Python TreeExplainer\n",
    "\n",
    "This uses numba to speed things up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class TreeExplainer:\n",
    "    def __init__(self, model, **kwargs):\n",
    "        \n",
    "        if str(type(model)).endswith(\"sklearn.ensemble.forest.RandomForestRegressor'>\"):\n",
    "            self.trees = [Tree(e.tree_) for e in model.estimators_]\n",
    "            \n",
    "        # Preallocate space for the unique path data\n",
    "        maxd = np.max([t.max_depth for t in self.trees]) + 2\n",
    "        print(maxd)\n",
    "        s = (maxd * (maxd + 1)) // 2\n",
    "        self.feature_indexes = np.zeros(s, dtype=np.int32)\n",
    "        self.zero_fractions = np.zeros(s, dtype=np.float64)\n",
    "        self.one_fractions = np.zeros(s, dtype=np.float64)\n",
    "        self.pweights = np.zeros(s, dtype=np.float64)\n",
    "\n",
    "    def shap_values(self, X, **kwargs):\n",
    "        # convert dataframes\n",
    "        if str(type(X)).endswith(\"pandas.core.series.Series'>\"):\n",
    "            X = X.values\n",
    "        elif str(type(X)).endswith(\"'pandas.core.frame.DataFrame'>\"):\n",
    "            X = X.values\n",
    "\n",
    "        assert str(type(X)).endswith(\"'numpy.ndarray'>\"), \"Unknown instance type: \" + str(type(X))\n",
    "        assert len(X.shape) == 1 or len(X.shape) == 2, \"Instance must have 1 or 2 dimensions!\"\n",
    "\n",
    "        # single instance\n",
    "        if len(X.shape) == 1:\n",
    "            phi = np.zeros(X.shape[0] + 1)\n",
    "            x_missing = np.zeros(X.shape[0], dtype=bool)\n",
    "            for t in self.trees:\n",
    "                self.tree_shap(t, X, x_missing, phi)\n",
    "            phi /= len(self.trees)\n",
    "        elif len(X.shape) == 2:\n",
    "            phi = np.zeros((X.shape[0], X.shape[1] + 1))\n",
    "            x_missing = np.zeros(X.shape[1], dtype=bool)\n",
    "            for i in range(X.shape[0]):\n",
    "                for t in self.trees:\n",
    "                    self.tree_shap(t, X[i,:], x_missing, phi[i,:])\n",
    "            phi /= len(self.trees)\n",
    "        return phi\n",
    "    \n",
    "    def tree_shap(self, tree, x, x_missing, phi, condition=0, condition_feature=0):\n",
    "\n",
    "        # update the bias term, which is the last index in phi\n",
    "        # (note the paper has this as phi_0 instead of phi_M)\n",
    "        if condition == 0:\n",
    "            phi[-1] += tree.values[0]\n",
    "\n",
    "        # start the recursive algorithm\n",
    "        tree_shap_recursive(\n",
    "            tree.children_left, tree.children_right, tree.children_default, tree.features,\n",
    "            tree.thresholds, tree.values, tree.node_sample_weight,\n",
    "            x, x_missing, phi, 0, 0, self.feature_indexes, self.zero_fractions, self.one_fractions, self.pweights,\n",
    "            1, 1, -1, condition, condition_feature, 1\n",
    "        )\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# extend our decision path with a fraction of one and zero extensions\n",
    "@numba.jit(\n",
    "    numba.types.void(\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.int32,\n",
    "        numba.types.float64,\n",
    "        numba.types.float64,\n",
    "        numba.types.int32\n",
    "    ), nopython=True, nogil=True\n",
    ")\n",
    "def extend_path(feature_indexes, zero_fractions, one_fractions, pweights,\n",
    "                unique_depth, zero_fraction, one_fraction, feature_index):\n",
    "    feature_indexes[unique_depth] = feature_index\n",
    "    zero_fractions[unique_depth] = zero_fraction\n",
    "    one_fractions[unique_depth] = one_fraction\n",
    "    if unique_depth == 0: \n",
    "        pweights[unique_depth] = 1\n",
    "    else:\n",
    "        pweights[unique_depth] = 0\n",
    "    \n",
    "    for i in range(unique_depth - 1, -1, -1):\n",
    "        pweights[i+1] += one_fraction * pweights[i] * (i + 1) / (unique_depth + 1)\n",
    "        pweights[i] = zero_fraction * pweights[i] * (unique_depth - i) / (unique_depth + 1)\n",
    "        \n",
    "# undo a previous extension of the decision path\n",
    "@numba.jit(\n",
    "    numba.types.void(\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.int32,\n",
    "        numba.types.int32\n",
    "    ), nopython=True, nogil=True\n",
    ")\n",
    "def unwind_path(feature_indexes, zero_fractions, one_fractions, pweights,\n",
    "                unique_depth, path_index):\n",
    "    one_fraction = one_fractions[path_index]\n",
    "    zero_fraction = zero_fractions[path_index]\n",
    "    next_one_portion = pweights[unique_depth]\n",
    "\n",
    "    for i in range(unique_depth - 1, -1, -1):\n",
    "        if one_fraction != 0:\n",
    "            tmp = pweights[i]\n",
    "            pweights[i] = next_one_portion * (unique_depth + 1) / ((i + 1) * one_fraction)\n",
    "            next_one_portion = tmp - pweights[i] * zero_fraction * (unique_depth - i) / (unique_depth + 1)\n",
    "        else:\n",
    "            pweights[i] = (pweights[i] * (unique_depth + 1)) / (zero_fraction * (unique_depth - i))\n",
    "\n",
    "    for i in range(path_index, unique_depth):\n",
    "        feature_indexes[i] = feature_indexes[i+1]\n",
    "        zero_fractions[i] = zero_fractions[i+1]\n",
    "        one_fractions[i] = one_fractions[i+1]\n",
    "        \n",
    "# determine what the total permuation weight would be if\n",
    "# we unwound a previous extension in the decision path\n",
    "@numba.jit(\n",
    "    numba.types.float64(\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.int32,\n",
    "        numba.types.int32\n",
    "    ),\n",
    "    nopython=True, nogil=True\n",
    ")\n",
    "def unwound_path_sum(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):\n",
    "    one_fraction = one_fractions[path_index]\n",
    "    zero_fraction = zero_fractions[path_index]\n",
    "    next_one_portion = pweights[unique_depth]\n",
    "    total = 0\n",
    "    \n",
    "    for i in range(unique_depth - 1, -1, -1):\n",
    "        if one_fraction != 0:\n",
    "            tmp = next_one_portion * (unique_depth + 1) / ((i + 1) * one_fraction)\n",
    "            total += tmp;\n",
    "            next_one_portion = pweights[i] - tmp * zero_fraction * ((unique_depth - i) / (unique_depth + 1))\n",
    "        else:\n",
    "            total += (pweights[i] / zero_fraction) / ((unique_depth - i) / (unique_depth + 1))\n",
    "\n",
    "    return total\n",
    "\n",
    "\n",
    "class Tree:\n",
    "    def __init__(self, children_left, children_right, children_default, feature, threshold, value, node_sample_weight):\n",
    "        self.children_left = children_left.astype(np.int32)\n",
    "        self.children_right = children_right.astype(np.int32)\n",
    "        self.children_default = children_default.astype(np.int32)\n",
    "        self.features = feature.astype(np.int32)\n",
    "        self.thresholds = threshold\n",
    "        self.values = value\n",
    "        self.node_sample_weight = node_sample_weight\n",
    "        \n",
    "        self.max_depth = compute_expectations(\n",
    "            self.children_left, self.children_right, self.node_sample_weight,\n",
    "            self.values, 0\n",
    "        )\n",
    "    \n",
    "    def __init__(self, tree):\n",
    "        if str(type(tree)).endswith(\"'sklearn.tree._tree.Tree'>\"):\n",
    "            self.children_left = tree.children_left.astype(np.int32)\n",
    "            self.children_right = tree.children_right.astype(np.int32)\n",
    "            self.children_default = self.children_left\n",
    "            if hasattr(tree, \"missing_go_to_left\"):\n",
    "                self.children_default = np.where(tree.missing_go_to_left, self.children_left, self.children_right)                \n",
    "            self.features = tree.feature.astype(np.int32)\n",
    "            self.thresholds = tree.threshold.astype(np.float64)\n",
    "            self.values = tree.value[:,0,0] # assume only a single output for now\n",
    "            self.node_sample_weight = tree.weighted_n_node_samples.astype(np.float64)\n",
    "            \n",
    "            # we recompute the expectations to make sure they follow the SHAP logic\n",
    "            self.max_depth = compute_expectations(\n",
    "                self.children_left, self.children_right, self.node_sample_weight,\n",
    "                self.values, 0\n",
    "            )\n",
    "\n",
    "@numba.jit(nopython=True)\n",
    "def compute_expectations(children_left, children_right, node_sample_weight, values, i, depth=0):\n",
    "    if children_right[i] == -1:\n",
    "        values[i] = values[i]\n",
    "        return 0\n",
    "    else:\n",
    "        li = children_left[i]\n",
    "        ri = children_right[i]\n",
    "        depth_left = compute_expectations(children_left, children_right, node_sample_weight, values, li, depth + 1)\n",
    "        depth_right = compute_expectations(children_left, children_right, node_sample_weight, values, ri, depth + 1)\n",
    "        left_weight = node_sample_weight[li]\n",
    "        right_weight = node_sample_weight[ri]\n",
    "        v = (left_weight * values[li] + right_weight * values[ri]) / (left_weight + right_weight)\n",
    "        values[i] = v\n",
    "        return max(depth_left, depth_right) + 1\n",
    "        \n",
    "# recursive computation of SHAP values for a decision tree\n",
    "@numba.jit(\n",
    "    numba.types.void(\n",
    "        numba.types.int32[:],\n",
    "        numba.types.int32[:],\n",
    "        numba.types.int32[:],\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.boolean[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.int64,\n",
    "        numba.types.int64,\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64,\n",
    "        numba.types.float64,\n",
    "        numba.types.int64,\n",
    "        numba.types.int64,\n",
    "        numba.types.int64,\n",
    "        numba.types.float64,\n",
    "    ),\n",
    "    nopython=True, nogil=True\n",
    ")\n",
    "def tree_shap_recursive(children_left, children_right, children_default, features, thresholds, values, node_sample_weight,\n",
    "                        x, x_missing, phi, node_index, unique_depth, parent_feature_indexes,\n",
    "                        parent_zero_fractions, parent_one_fractions, parent_pweights, parent_zero_fraction,\n",
    "                        parent_one_fraction, parent_feature_index, condition, condition_feature, condition_fraction):\n",
    "\n",
    "    # stop if we have no weight coming down to us\n",
    "    if condition_fraction == 0:\n",
    "        return\n",
    "\n",
    "    # extend the unique path\n",
    "    feature_indexes = parent_feature_indexes[unique_depth + 1:]\n",
    "    feature_indexes[:unique_depth + 1] = parent_feature_indexes[:unique_depth + 1]\n",
    "    zero_fractions = parent_zero_fractions[unique_depth + 1:]\n",
    "    zero_fractions[:unique_depth + 1] = parent_zero_fractions[:unique_depth + 1]\n",
    "    one_fractions = parent_one_fractions[unique_depth + 1:]\n",
    "    one_fractions[:unique_depth + 1] = parent_one_fractions[:unique_depth + 1]\n",
    "    pweights = parent_pweights[unique_depth + 1:]\n",
    "    pweights[:unique_depth + 1] = parent_pweights[:unique_depth + 1]\n",
    "\n",
    "    if condition == 0 or condition_feature != parent_feature_index:\n",
    "        extend_path(\n",
    "            feature_indexes, zero_fractions, one_fractions, pweights,\n",
    "            unique_depth, parent_zero_fraction, parent_one_fraction, parent_feature_index\n",
    "        )\n",
    "\n",
    "    split_index = features[node_index]\n",
    "\n",
    "    # leaf node\n",
    "    if children_right[node_index] == -1:\n",
    "        for i in range(1, unique_depth+1):\n",
    "            w = unwound_path_sum(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, i)\n",
    "            phi[feature_indexes[i]] += w * (one_fractions[i] - zero_fractions[i]) * values[node_index] * condition_fraction\n",
    "\n",
    "    # internal node\n",
    "    else:\n",
    "        # find which branch is \"hot\" (meaning x would follow it)\n",
    "        hot_index = 0\n",
    "        cleft = children_left[node_index]\n",
    "        cright = children_right[node_index]\n",
    "        if x_missing[split_index] == 1:\n",
    "            hot_index = children_default[node_index]\n",
    "        elif x[split_index] < thresholds[node_index]:\n",
    "            hot_index = cleft\n",
    "        else:\n",
    "            hot_index = cright\n",
    "        cold_index = (cright if hot_index == cleft else cleft)\n",
    "        w = node_sample_weight[node_index]\n",
    "        hot_zero_fraction = node_sample_weight[hot_index] / w\n",
    "        cold_zero_fraction = node_sample_weight[cold_index] / w\n",
    "        incoming_zero_fraction = 1\n",
    "        incoming_one_fraction = 1\n",
    "\n",
    "        # see if we have already split on this feature,\n",
    "        # if so we undo that split so we can redo it for this node\n",
    "        path_index = 0\n",
    "        while (path_index <= unique_depth):\n",
    "            if feature_indexes[path_index] == split_index:\n",
    "                break\n",
    "            path_index += 1\n",
    "\n",
    "        if path_index != unique_depth + 1:\n",
    "            incoming_zero_fraction = zero_fractions[path_index]\n",
    "            incoming_one_fraction = one_fractions[path_index]\n",
    "            unwind_path(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index)\n",
    "            unique_depth -= 1\n",
    "\n",
    "        # divide up the condition_fraction among the recursive calls\n",
    "        hot_condition_fraction = condition_fraction\n",
    "        cold_condition_fraction = condition_fraction\n",
    "        if condition > 0 and split_index == condition_feature:\n",
    "            cold_condition_fraction = 0;\n",
    "            unique_depth -= 1\n",
    "        elif condition < 0 and split_index == condition_feature:\n",
    "            hot_condition_fraction *= hot_zero_fraction\n",
    "            cold_condition_fraction *= cold_zero_fraction\n",
    "            unique_depth -= 1\n",
    "\n",
    "        tree_shap_recursive(\n",
    "            children_left, children_right, children_default, features, thresholds, values, node_sample_weight,\n",
    "            x, x_missing, phi, hot_index, unique_depth + 1,\n",
    "            feature_indexes, zero_fractions, one_fractions, pweights,\n",
    "            hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction,\n",
    "            split_index, condition, condition_feature, hot_condition_fraction\n",
    "        )\n",
    "\n",
    "        tree_shap_recursive(\n",
    "            children_left, children_right, children_default, features, thresholds, values, node_sample_weight,\n",
    "            x, x_missing, phi, cold_index, unique_depth + 1,\n",
    "            feature_indexes, zero_fractions, one_fractions, pweights,\n",
    "            cold_zero_fraction * incoming_zero_fraction, 0,\n",
    "            split_index, condition, condition_feature, cold_condition_fraction\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explain a prediction of one input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([ 3.86778154e-01, -2.91701989e-04,  2.21392679e-02,  3.21670105e-02,\n",
       "       -2.43900987e-01, -1.72050637e+00, -7.88487600e-03,  1.32586529e+01,\n",
       "       -2.50914664e-02,  1.64892740e-01,  1.16804717e-01, -4.67221730e-01,\n",
       "        9.75267653e+00,  2.25335138e+01])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = np.ones(X.shape[1])\n",
    "TreeExplainer(model).shap_values(x)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
